/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.parse.relnodegen;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.metadata.RelColumnMapping;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.ql.parse.ASTErrorUtils;
import org.apache.hadoop.hive.ql.parse.ASTNode;
import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.CalcitePlanner;
import org.apache.hadoop.hive.ql.parse.HiveParser;
import org.apache.hadoop.hive.ql.parse.RowResolver;
import org.apache.hadoop.hive.ql.parse.SemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.UnparseTranslator;
import org.apache.hadoop.hive.ql.parse.type.FunctionHelper;
import org.apache.hadoop.hive.ql.parse.type.TypeCheckCtx;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * LateralViewPlan is a helper class holding the objects needed for generating a Calcite
 * plan from an ASTNode. The object is to be generated when a LATERAL_VIEW token is detected.
 * From the ASTNode and relevant input node information, the following objects are created:
 * * A HiveTableFunctionScan RelNode
 * * A RowResolver containing the row resolving information output from the RelNode
 * * The table alias for the table generated by the UDTF and Lateral View
 */
public class LateralViewPlan {
  protected static final Logger LOG = LoggerFactory.getLogger(LateralViewPlan.class.getName());

  // Only acceptable token types under the TOK_LATERAL_VIEW token.
  public static final ImmutableSet<Integer> TABLE_ALIAS_TOKEN_TYPES =
      ImmutableSet.of(HiveParser.TOK_SUBQUERY, HiveParser.TOK_TABREF, HiveParser.TOK_PTBLFUNCTION);

  // The RelNode created for this lateral view
  public final RelNode lateralViewRel;

  // The output RowResolver created for this lateral view.
  public final RowResolver outputRR;

  // The alias provided for the lateral table in the query
  public final String lateralTableAlias;

  private final RelOptCluster cluster;
  private final UnparseTranslator unparseTranslator;
  private final HiveConf conf;
  private final FunctionHelper functionHelper;

  public LateralViewPlan(ASTNode lateralView, RelOptCluster cluster, RelNode inputRel,
      RowResolver inputRR, UnparseTranslator unparseTranslator,
      HiveConf conf, FunctionHelper functionHelper
      ) throws SemanticException {
    // initialize global variables containing helper information
    this.cluster = cluster;
    this.unparseTranslator = unparseTranslator;
    this.conf = conf;
    this.functionHelper = functionHelper;

    // AST should have form of LATERAL_VIEW -> SELECT -> SELEXPR -> FUNCTION -> function info tree
    ASTNode selExprAST = (ASTNode) lateralView.getChild(0).getChild(0);
    ASTNode functionAST = (ASTNode) selExprAST.getChild(0);

    this.lateralTableAlias = getTableAliasFromASTNode(selExprAST);

    // The RexCall for the lateral function (e.g. lateral(inline(), $0, $1, ...)), where
    // the inputrefs are all retrieved from the input RelNode.
    RexCall udtfCall = getLateralFunction(functionAST, inputRR, inputRel);

    // Column aliases provided by the query.
    List<String> columnAliases = getColumnAliasesFromASTNode(selExprAST, udtfCall);

    this.outputRR = getOutputRR(inputRR, udtfCall, columnAliases, this.lateralTableAlias);

    RelDataType retType = getRetType(cluster, inputRel, udtfCall, columnAliases);

    this.lateralViewRel = HiveTableFunctionScan.create(cluster,
        TraitsUtil.getDefaultTraitSet(cluster), ImmutableList.of(inputRel), udtfCall,
        null, retType, createColumnMappings(inputRel));
  }

  public static void validateLateralView(ASTNode lateralView) throws SemanticException {
    if (lateralView.getChildCount() != 2) {
      throw new SemanticException("Token Lateral View contains " + lateralView.getChildCount() +
          " children.");
    }
    ASTNode next = (ASTNode) lateralView.getChild(1);
    if (!TABLE_ALIAS_TOKEN_TYPES.contains(next.getToken().getType()) &&
          HiveParser.TOK_LATERAL_VIEW != next.getToken().getType()) {
        throw new SemanticException(ASTErrorUtils.getMsg(
            ErrorMsg.LATERAL_VIEW_INVALID_CHILD.getMsg(), lateralView));
    }
  }

  private RexCall getLateralFunction(ASTNode functionAST, RowResolver inputRR, RelNode inputRel)
      throws SemanticException {
    RexCall udtfCall = getUDTFFunction(functionAST, inputRR);
    List<RexNode> operands = new ArrayList<>();
    operands.add(udtfCall);
    for (int i = 0; i < inputRel.getRowType().getFieldCount(); ++i) {
      RelDataType type = inputRel.getRowType().getFieldList().get(i).getType();
      operands.add(this.cluster.getRexBuilder().makeInputRef(type, i));
    }
    return (RexCall) this.cluster.getRexBuilder().makeCall(SqlStdOperatorTable.LATERAL, operands);
  }

  private RexCall getUDTFFunction(ASTNode functionAST, RowResolver inputRR)
      throws SemanticException {

    String functionName = functionAST.getChild(0).getText().toLowerCase();

    // create the RexNode operands for the UDTF RexCall
    List<RexNode> operandsForUDTF = getOperandsForUDTF(functionAST, inputRR);

    return this.functionHelper.getUDTFFunction(functionName, operandsForUDTF);
  }

  private String getTableAliasFromASTNode(ASTNode selExprClause) throws SemanticException {
    // loop through the AST and find the TOK_TABALIAS object
    for (Node obj : selExprClause.getChildren()) {
      ASTNode child = (ASTNode) obj;
      if (child.getToken().getType() == HiveParser.TOK_TABALIAS) {
        return BaseSemanticAnalyzer.unescapeIdentifier(child.getChild(0).getText().toLowerCase());
      }
    }

    // Parser enforces that table alias is added, but check again
    throw new SemanticException("Alias should be specified LVJ");
  }

  private List<String> getColumnAliasesFromASTNode(ASTNode selExprClause,
      RexCall udtfCall) throws SemanticException {
    Set<String> uniqueNames = new HashSet<>();
    List<String> colAliases = new ArrayList<>();
    for (Node obj : selExprClause.getChildren()) {
      ASTNode child = (ASTNode) obj;
      // Skip the token values.  The rest should be the identifier column aliases
      if (child.getToken().getType() == HiveParser.TOK_TABALIAS ||
          child.getToken().getType() == HiveParser.TOK_FUNCTION) {
        continue;
      }
      String colAlias = BaseSemanticAnalyzer.unescapeIdentifier(child.getText().toLowerCase());
      if (uniqueNames.contains(colAlias)) {
        // Column aliases defined by query for lateral view output are duplicated
        throw new SemanticException(ErrorMsg.COLUMN_ALIAS_ALREADY_EXISTS.getMsg(colAlias));
      }
      uniqueNames.add(colAlias);
      colAliases.add(colAlias);
    }

    // if no column aliases were provided, just retrieve them from the return type
    // of the udtf RexCall
    if (colAliases.isEmpty()) {
      colAliases.addAll(
          Lists.transform(udtfCall.getType().getFieldList(), RelDataTypeField::getName));
    }

    // Verify that there is an alias for all the columns returned by the udtf call.
    int udtfFieldCount = udtfCall.getType().getFieldCount();
    if (colAliases.size() != udtfFieldCount) {
      // Number of columns in the aliases does not match with number of columns
      // generated by the lateral view
      throw new SemanticException(ErrorMsg.UDTF_ALIAS_MISMATCH.getMsg(
          "expected " + udtfFieldCount + " aliases but got " + colAliases.size()));
    }

    return colAliases;
  }

  private List<RexNode> getOperandsForUDTF(ASTNode functionCall,
      RowResolver inputRR) throws SemanticException {
    List<RexNode> operands = new ArrayList<>();
    TypeCheckCtx tcCtx = new TypeCheckCtx(inputRR, this.cluster.getRexBuilder(), false, false);
    tcCtx.setUnparseTranslator(this.unparseTranslator);
    // Start at 1 because value 0 is the function name.  Use the CalcitePlanner.genRexNode
    // to retrieve the RexNode for all the function parameters.
    for (int i = 1; i < functionCall.getChildren().size(); ++i) {
      ASTNode functionParam = (ASTNode) functionCall.getChild(i);
      operands.add(CalcitePlanner.genRexNode(functionParam, inputRR, tcCtx, this.conf));
    }
    return operands;
  }

  private RowResolver getOutputRR(RowResolver inputRR, RexCall udtfCall,
      List<String> columnAliases, String lateralTableAlias) throws SemanticException {

    RowResolver localOutputRR = new RowResolver();

    // After calling RowResolver, outputRR will be mutated to contain the row resolver
    // fields
    if (!RowResolver.add(localOutputRR, inputRR)) {
      LOG.warn("Duplicates detected when adding columns to RR: see previous message");
    }

    // The RexNode return value for a udtf is always a struct.
    TypeInfo typeInfo = TypeConverter.convert(udtfCall.getType());
    Preconditions.checkState(typeInfo instanceof StructTypeInfo);

    StructTypeInfo typeInfos = (StructTypeInfo) typeInfo;
    // Match up the column alias with the return value of the udtf and
    // place in the outputRR
    for (int i = 0, j = 0; i < columnAliases.size(); i++) {
      String internalColName;
      do {
        internalColName = SemanticAnalyzer.getColumnInternalName(j++);
      } while (localOutputRR.getPosition(internalColName) != -1);
      localOutputRR.put(lateralTableAlias, columnAliases.get(i),
          new ColumnInfo(internalColName,  typeInfos.getAllStructFieldTypeInfos().get(i),
              lateralTableAlias, false));
    }
    return localOutputRR;
  }

  /**
   * get the return type for the lateral view. The return type will be a structure
   * where the first values are from the input RelNode. Then the output values from the
   * udtf will be added into the structure for the return type.
   */
  private RelDataType getRetType(RelOptCluster cluster, RelNode inputRel,
      RexNode udtfCall, List<String> columnAliases) {

    // initialize allDataTypes and allDataTypeNames from the fields in the inputRel
    List<RelDataType> allDataTypes = new ArrayList<>( Lists.transform(inputRel.getRowType().getFieldList(), RelDataTypeField::getType));
    List<String> allDataTypeNames = new ArrayList<>(
        Lists.transform(inputRel.getRowType().getFieldList(), RelDataTypeField::getName));

    RelDataType retType = udtfCall.getType();

    Preconditions.checkState(retType.isStruct());

    // Add the type names and values from the udtf into the lists that will make up the
    // return type. Names need to be unique so add the table prefix
    allDataTypes.addAll(Lists.transform(retType.getFieldList(), RelDataTypeField::getType));
    for (String s : columnAliases) {
      allDataTypeNames.add(lateralTableAlias + "." + s);
    }

    return cluster.getTypeFactory().createStructType(allDataTypes, allDataTypeNames);
  }

  private Set<RelColumnMapping> createColumnMappings(RelNode inputRel) {
    Set<RelColumnMapping> colMappings = new HashSet<>();
    for (int i = 0; i < inputRel.getRowType().getFieldCount(); ++i) {
      colMappings.add(new RelColumnMapping(i, 0, i, false));
    }
    return colMappings;
  }
}
